import asyncio
import os

import numpy as np
import matplotlib.pyplot as plt

from numpy.polynomial.polynomial import polyfit

from py_pli.pylib import VUnits
from py_pli.pylib import GlobalVar
from py_pli.pylib import send_gc_event

from config_enum import hal_enum as hal_config

from virtualunits.HAL import HAL
from virtualunits.meas_seq_generator import meas_seq_generator
from virtualunits.meas_seq_generator import OutputSignal
from virtualunits.meas_seq_generator import MeasurementChannel

from urpc_enum.measurementparameter import MeasurementParameter

from fleming.common.firmware_util import *

# Hardware abstraction handles shared by all test functions in this script.
hal_unit: HAL = VUnits.instance.hal
meas_unit = hal_unit.measurementUnit

# Firmware endpoints used below: 'fmb' drives the base tester PSUON output, 'eef' provides the
# USL tester switches / analog inputs and HTS Alpha I/O, the measurement endpoint exposes the
# MeasurementParameter settings (laser, high voltage, discriminator level).
fmb_endpoint = get_node_endpoint('fmb')
eef_endpoint = get_node_endpoint('eef')
meas_endpoint = get_measurement_endpoint()

# Directory for generated report files (e.g. the laser power scan graph); created if missing.
report_path = hal_unit.get_config(hal_config.Application.GCReportPath)
os.makedirs(report_path, exist_ok=True)

# Full-scale factors: a raw analog reading is multiplied by its factor to get the physical value
# (units as printed in the log messages of usl_test).
laser_current_fs = 1.335            # reading -> A; empirical value, since the gain of the amplifier is not 10 as expected
high_voltage_fs = 2500.0            # reading -> V (tester measurement of the USLUM high voltage)
high_voltage_monitor_fs = 1500.0    # reading -> V (PMTUSLUMHighVoltageMonitor parameter)
tec_voltage_fs = 10.08              # reading -> V; 10.08 or 10.26 with Fluke 115 in series
tec_current_fs = 6.25               # reading -> A


async def _read_hts_alpha_laser_state():
    """Read the HTS Alpha laser current and the temperature-error input from the USL tester.

    Returns:
        tuple: ``(laser_current, temp_error)`` -- the USLTESTERVIN1 analog reading scaled by
        ``laser_current_fs`` (logged as A) and the raw HTSALPHATEMPERROR digital input value.
    """
    laser_current = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.USLTESTERVIN1))[0] * laser_current_fs
    temp_error = (await eef_endpoint.GetDigitalInput(EEFDigitalInput.HTSALPHATEMPERROR))[0]
    return laser_current, temp_error


async def _set_and_read_usl_high_voltage(setting):
    """Apply a USLUM high-voltage setting, wait 0.2 s, then read, log and return the result.

    Args:
        setting: Value written to ``PMTUSLUMHighVoltageSetting`` (0.0, 0.5 or 1.0 in usl_test).

    Returns:
        tuple: ``(high_voltage, high_voltage_monitor)`` in V -- USLTESTERVIN2 scaled by
        ``high_voltage_fs`` and the PMTUSLUMHighVoltageMonitor parameter scaled by
        ``high_voltage_monitor_fs``.
    """
    await meas_endpoint.SetParameter(MeasurementParameter.PMTUSLUMHighVoltageSetting, setting)
    await asyncio.sleep(0.2)
    high_voltage = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.USLTESTERVIN2))[0] * high_voltage_fs
    high_voltage_monitor = (await meas_endpoint.GetParameter(MeasurementParameter.PMTUSLUMHighVoltageMonitor))[0] * high_voltage_monitor_fs

    PyLogger.logger.info(f"USLUM High Voltage: {setting}, high_voltage: {high_voltage:.0f} V, high_voltage_monitor: {high_voltage_monitor:.0f} V")
    return high_voltage, high_voltage_monitor


async def _set_and_read_hts_alpha_tec(setting):
    """Apply an HTS Alpha TEC analog output setting, wait 0.2 s, then read, log and return the result.

    The setting is inverted with respect to the output voltage: 1.0 is used for the 0 V check and
    smaller settings produce higher voltages (see the checks in usl_test).

    Args:
        setting: Value written to the HTSALPHATEC analog output.

    Returns:
        tuple: ``(tec_voltage, tec_current)`` -- USLTESTERVIN0 scaled by ``tec_voltage_fs`` (V) and
        USLTESTERVIN3 scaled by ``tec_current_fs`` (A).
    """
    await eef_endpoint.SetAnalogOutput(EEFAnalogOutput.HTSALPHATEC, setting)
    await asyncio.sleep(0.2)
    tec_voltage = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.USLTESTERVIN0))[0] * tec_voltage_fs
    tec_current = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.USLTESTERVIN3))[0] * tec_current_fs

    # Log the setting actually applied (previously hard-coded labels did not match the settings).
    PyLogger.logger.info(f"HTS Alpha TEC: {setting}, tec_voltage: {tec_voltage:.3f} V, tec_current: {tec_current:.3f} A")
    return tec_voltage, tec_current


async def _report_check(name, passed):
    """Send a ``PASSED  - <name>`` or ``FAILED  - <name>`` GC message and return ``passed``."""
    await send_gc_msg(f"{'PASSED' if passed else 'FAILED'}  - {name}")
    return passed


async def usl_test():
    """Run the complete USL test sequence on the DUT connected to the base tester.

    Steps: HTS Alpha laser current checks (power scan, excitation off/on, temperature shutdown),
    HTS Alpha photodiode readout (logged only, reported as SKIPPED), USLUM high voltage checks,
    HTS Alpha TEC voltage checks and the USLUM counting measurement test. Each step reports
    PASSED/FAILED/SKIPPED via ``send_gc_msg``; measurement details go to the log.

    The DUT supply is switched on via the FMB PSUON output (active low) and switched off again in
    ``finally``. The laser, high voltage and TEC sections reset their outputs in their own
    ``finally`` blocks.

    Returns:
        str: An overall result message, or "usl_test stopped by user" if the GC stop flag is set
        between steps.
    """
    stopped_msg = "usl_test stopped by user"

    GlobalVar.set_stop_gc(False)
    await send_gc_msg("Starting USL Test")

    await start_firmware('fmb')
    await fmb_endpoint.SetDigitalOutput(FMBDigitalOutput.PSUON, 0)   # Base Tester PSUON is active low
    try:
        await asyncio.sleep(0.1)
        await start_firmware('eef')
        if GlobalVar.get_stop_gc():
            return stopped_msg

        usl_test_error = False

        # All pass/fail decisions below are written as "value is inside the limits" and then
        # negated, so a NaN reading fails the check instead of slipping through (every
        # comparison with NaN is False).

        # LASER CURRENT CHECK ##########################################################################################
        try:
            await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH3, 1)
            await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 1)

            # LASER POWER SCAN
            PyLogger.logger.info("HTS Alpha Laser Power Scan:")
            PyLogger.logger.info("power [FS] ; current [A]")

            power_range = np.arange(0.0, (1.0 + 1e-6), 0.05).round(2)   # 0.00 .. 1.00 in 0.05 steps
            laser_current = np.zeros_like(power_range)
            for i, power in enumerate(power_range):
                if GlobalVar.get_stop_gc():
                    return stopped_msg

                await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserPower, power)
                await set_hts_alpha_excitation(enable=True)
                await asyncio.sleep(0.1)
                laser_current[i] = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.USLTESTERVIN1))[0] * laser_current_fs

                PyLogger.logger.info(f"{power:10.2f} ; {laser_current[i]:.3f}")

            await plot_laser_power_scan(power_range, laser_current)

            laser_current_max = np.max(laser_current)
            laser_current_slope = 0.0
            laser_current_intercept = 0.0
            laser_current_r_squared = 0.0

            # Fit a line through the 20 % .. 80 % part of the curve only. At least two points are
            # required: with a single point polyfit is rank deficient and corrcoef returns NaN.
            linear_range = np.logical_and((laser_current > (laser_current_max * 0.2)), (laser_current <= (laser_current_max * 0.8)))
            if np.count_nonzero(linear_range) >= 2:
                laser_current_intercept, laser_current_slope = polyfit(power_range[linear_range], laser_current[linear_range], deg=1)
                laser_current_r_squared = np.corrcoef(power_range[linear_range], laser_current[linear_range])[0, 1] ** 2

            # Relative deviations from the nominal maximum current (0.900) and slope (1.006).
            laser_current_max_error = abs((laser_current_max - 0.900) / 0.900)
            laser_current_slope_error = abs((laser_current_slope - 1.006) / 1.006)

            PyLogger.logger.info(f"laser_current_max: {laser_current_max:.3f}, laser_current_max_error: {laser_current_max_error:.2%}")
            PyLogger.logger.info(f"laser_current_slope: {laser_current_slope:.3f}, laser_current_slope_error: {laser_current_slope_error:.2%}")
            PyLogger.logger.info(f"laser_current_intercept: {laser_current_intercept:.3f}")
            PyLogger.logger.info(f"laser_current_r_squared: {laser_current_r_squared:.6f}")

            scan_ok = (0.750 <= laser_current_max <= 1.050
                       and laser_current_slope_error <= 0.1
                       and laser_current_r_squared >= 0.99998)
            if not await _report_check("HTS Alpha Laser Power Scan", scan_ok):
                usl_test_error = True

            await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserPower, 1.0)

            # EXCITATION OFF
            await set_hts_alpha_excitation(enable=False)
            await asyncio.sleep(0.1)
            laser_current, temp_error = await _read_hts_alpha_laser_state()

            PyLogger.logger.info(f"Laser Excitation: OFF, laser_current: {laser_current:.3f} A, temp_error: {temp_error}")

            if not await _report_check("HTS Alpha Laser Excitation OFF", laser_current <= 0.01):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            # EXCITATION ON / TEMPERATURE ERROR OFF
            await set_hts_alpha_excitation(enable=True)
            await asyncio.sleep(0.1)
            laser_current, temp_error = await _read_hts_alpha_laser_state()

            PyLogger.logger.info(f"Laser Excitation: ON, laser_current: {laser_current:.3f} A, temp_error: {temp_error}")

            excitation_on_ok = 0.750 <= laser_current <= 1.050 and temp_error == 0
            if not await _report_check("HTS Alpha Laser Excitation ON", excitation_on_ok):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            # TEMPERATURE ERROR ON: with USLTESTERSWITCH3 = 0 the temperature error input is
            # expected to be set and the laser current to drop to ~0.
            await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH3, 0)
            await asyncio.sleep(0.1)
            laser_current, temp_error = await _read_hts_alpha_laser_state()

            PyLogger.logger.info(f"Temperature Error: ON, laser_current: {laser_current:.3f} A, temp_error: {temp_error}")

            shutdown_ok = laser_current <= 0.01 and temp_error == 1
            if not await _report_check("HTS Alpha Laser Temperature Shutdown", shutdown_ok):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH3, 1)

        finally:
            await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserPower, 0.0)
            await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 0)

        # PHOTODIODE FEEDBACK CHECK ####################################################################################
        # The photodiode readings are only logged; no pass/fail limits are applied, so every
        # step is reported as SKIPPED.

        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH1, 0)
        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH2, 0)
        await asyncio.sleep(0.1)
        photo_current = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.HTSALPHAPHOTODIODE))[0]

        PyLogger.logger.info(f"HTS Alpha Photodiode S1=0, S2=0, photo_current: {photo_current:.3f}")

        await send_gc_msg("SKIPPED - HTS Alpha Photodiode Test 1")

        if GlobalVar.get_stop_gc():
            return stopped_msg

        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH1, 1)
        await asyncio.sleep(0.1)
        photo_current = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.HTSALPHAPHOTODIODE))[0]

        PyLogger.logger.info(f"HTS Alpha Photodiode S2=0, S1=1, photo_current: {photo_current:.3f}")

        await send_gc_msg("SKIPPED - HTS Alpha Photodiode Test 2")

        if GlobalVar.get_stop_gc():
            return stopped_msg

        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH2, 1)
        await asyncio.sleep(0.1)
        photo_current = (await eef_endpoint.GetAnalogInput(EEFAnalogInput.HTSALPHAPHOTODIODE))[0]

        PyLogger.logger.info(f"HTS Alpha Photodiode S2=1, S1=1, photo_current: {photo_current:.3f}")

        await send_gc_msg("SKIPPED - HTS Alpha Photodiode Test 3")

        if GlobalVar.get_stop_gc():
            return stopped_msg

        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH1, 0)
        await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.USLTESTERSWITCH2, 0)

        # HIGH VOLTAGE CHECK ###########################################################################################
        try:
            await meas_endpoint.SetParameter(MeasurementParameter.PMTUSLUMHighVoltageEnable, 1)
            await asyncio.sleep(1.0)

            # HIGH VOLTAGE = 0 V
            high_voltage, high_voltage_monitor = await _set_and_read_usl_high_voltage(0.0)
            hv_ok = high_voltage <= 20 and high_voltage_monitor <= 20
            if not await _report_check("USLUM High Voltage = 0 V", hv_ok):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            # HIGH VOLTAGE = 750 V
            high_voltage, high_voltage_monitor = await _set_and_read_usl_high_voltage(0.5)
            hv_ok = 675 <= high_voltage <= 825 and 675 <= high_voltage_monitor <= 825
            if not await _report_check("USLUM High Voltage = 750 V", hv_ok):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            # HIGH VOLTAGE = 1500 V
            high_voltage, high_voltage_monitor = await _set_and_read_usl_high_voltage(1.0)
            hv_ok = 1350 <= high_voltage <= 1650 and 1350 <= high_voltage_monitor <= 1650
            if not await _report_check("USLUM High Voltage = 1500 V", hv_ok):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

        finally:
            await meas_endpoint.SetParameter(MeasurementParameter.PMTUSLUMHighVoltageSetting, 0.0)
            await meas_endpoint.SetParameter(MeasurementParameter.PMTUSLUMHighVoltageEnable, 0)

        # PELTIER POWER CHECK ##########################################################################################
        try:
            await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.HTSALPHATECENABLE, 1)

            # HTS ALPHA TEC 0 V (setting 1.0)
            tec_voltage, _ = await _set_and_read_hts_alpha_tec(1.0)
            if not await _report_check("HTS Alpha TEC Voltage = 0 V", tec_voltage <= 0.2):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            # HTS ALPHA TEC 3.55 V (~1.5 A)
            tec_voltage, _ = await _set_and_read_hts_alpha_tec(0.56)
            if not await _report_check("HTS Alpha TEC Voltage = 3.55 V", 3.37 <= tec_voltage <= 3.73):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

            # HTS ALPHA TEC 7.17 V (~3.0 A)
            tec_voltage, _ = await _set_and_read_hts_alpha_tec(0.12)
            if not await _report_check("HTS Alpha TEC Voltage = 7.17 V", 6.81 <= tec_voltage <= 7.53):
                usl_test_error = True

            if GlobalVar.get_stop_gc():
                return stopped_msg

        finally:
            await eef_endpoint.SetAnalogOutput(EEFAnalogOutput.HTSALPHATEC, 1.0)
            await eef_endpoint.SetDigitalOutput(EEFDigitalOutput.HTSALPHATECENABLE, 0)

        # COUNTING MEASUREMENT TEST with INTERNAL 250MHZ CLOCK #########################################################

        result = await usl_counting_measurement_test()
        if GlobalVar.get_stop_gc():
            return stopped_msg

        if not await _report_check("USLUM Counting Measurement Test", result == 'PASSED'):
            usl_test_error = True

        ################################################################################################################

        if usl_test_error:
            return "USL test failed. Check output and log files for details."
        return "USL test successful. Continue with next DUT."

    finally:
        await fmb_endpoint.SetDigitalOutput(FMBDigitalOutput.PSUON, 1)   # Base Tester PSUON is active low


async def plot_laser_power_scan(power, current, file_name='graph.png'):
    """Plot laser current versus laser power and notify the GC to refresh its graph.

    The plot is drawn into matplotlib's current figure (cleared first), saved as
    ``os.path.join(report_path, file_name)``, and a 'RefreshGraph' GC event carrying
    ``file_name`` is sent afterwards.

    Args:
        power: X values (laser power settings).
        current: Y values (measured laser current).
        file_name: Image file name inside ``report_path``.
    """
    plt.clf()
    axes = plt.gca()

    axes.set_title("HTS Alpha Laser Power Scan")
    axes.set_xlabel('Laser Power')
    axes.set_ylabel("Laser Current")
    axes.plot(power, current, color='b')

    image_path = os.path.join(report_path, file_name)
    plt.savefig(image_path)
    await send_gc_event('RefreshGraph', file_name=file_name)


async def set_hts_alpha_excitation(enable=True):
    """Set or reset the HTS_Alpha output signal through a one-shot trigger sequence.

    A minimal sequence is generated that only sets (``enable=True``) or resets
    (``enable=False``) ``OutputSignal.HTS_Alpha`` and then stops. All operations of the
    measurement unit are cleared, and the sequence is loaded under 'hts_alpha_on' or
    'hts_alpha_off' and executed.
    """
    sequence = meas_seq_generator()
    if enable:
        op_id = 'hts_alpha_on'
        sequence.SetSignals(OutputSignal.HTS_Alpha)
    else:
        op_id = 'hts_alpha_off'
        sequence.ResetSignals(OutputSignal.HTS_Alpha)
    sequence.Stop(0)

    meas_unit.ClearOperations()
    await meas_unit.LoadTriggerSequence(op_id, sequence.currSequence)
    await meas_unit.ExecuteMeasurement(op_id)


async def usl_counting_measurement_test():
    """Scan the USLUM discriminator level and evaluate the pulse counting result.

    For each discriminator level (dl) from 0.3 to 1.0 in 0.01 steps, one 100 ms counting
    measurement is executed, and the counts per second and the dead-time value are logged.
    Pass criteria:
      * the count rate at ``dl_frequency`` is within ``frequency_tolerance`` of ``frequency_reference``,
      * the highest dl that still counts at least (1 - tolerance) * reference ("amplitude")
        is above ``amplitude_min``,
      * no dl above ``amplitude + noise_limit`` counts more than ``noise_cps``.

    Returns:
        str: 'PASSED' or 'FAILED', or "usl_counting_measurement_test stopped by user" if the
        GC stop flag is set during the scan.
    """
    dl_start = 0.3
    dl_stop = 1.0
    dl_step = 0.01

    dl_frequency = 0.38                     # Measure signal frequency at this dl

    frequency_reference = 125e6             # Signal frequency in Hz
    frequency_tolerance = 0.02              # 2% tolerance for relative error

    amplitude_min = dl_frequency + 0.01     # amplitude must be greater than dl_frequency

    noise_cps = 10.0                        # Only CPS values above this value are counted as noise
    noise_limit = 0.2                       # No noise for dl values > amplitude + noise_limit

    window_ms = 100

    dl_range = np.arange(dl_start, (dl_stop + 1e-6), dl_step).round(6)

    cps = np.zeros_like(dl_range)   # Counts Per Second
    dt = np.zeros_like(dl_range)    # Dead Time

    op_id = 'pmt3_counting_measurement'
    meas_unit.ClearOperations()
    await load_pmt3_counting_measurement(op_id, window_ms)

    PyLogger.logger.info("Counting Measurement Test")
    PyLogger.logger.info("dl    ; cps       ; dt")

    for i, dl in enumerate(dl_range):
        if GlobalVar.get_stop_gc():
            return "usl_counting_measurement_test stopped by user"

        await meas_endpoint.SetParameter(MeasurementParameter.PMTUSLUMDiscriminatorLevel, dl)
        await asyncio.sleep(0.1)
        await meas_unit.ExecuteMeasurement(op_id)
        results = await meas_unit.ReadMeasurementValues(op_id)

        # Results 0/1 and 2/3 are the LSB/MSB words of the counter and dead-time values;
        # dividing by the window length gives a per-second rate.
        cps[i] = (results[0] + (results[1] << 32)) / window_ms * 1000.0
        dt[i] = (results[2] + (results[3] << 32)) / window_ms * 1000.0

        PyLogger.logger.info(f"{dl:.3f} ; {cps[i]:9.0f} ; {dt[i]:9.0f}")

    # Use the grid point closest to dl_frequency. Exact float equality would raise IndexError
    # as soon as dl_frequency is not bit-identical to a rounded grid value.
    frequency = cps[np.argmin(np.abs(dl_range - dl_frequency))]
    frequency_error = abs((frequency - frequency_reference) / frequency_reference)

    PyLogger.logger.info(f"frequency: {frequency / 1e6:.2f} MHz, frequency_error: {frequency_error:.2%}")

    # Highest dl at which the full signal rate is still counted (0.0 if never reached).
    amplitude = dl_range[cps >= (frequency_reference * (1 - frequency_tolerance))]
    amplitude = np.max(amplitude) if len(amplitude) else 0.0

    PyLogger.logger.info(f"amplitude: {amplitude:.3f} FS")

    # Highest dl that still sees more than noise_cps (1.0 if no dl does).
    noise_max = dl_range[cps > noise_cps]
    noise_max = np.max(noise_max) if len(noise_max) else 1.0

    PyLogger.logger.info(f"noise_max: {noise_max} FS")

    if frequency_error > frequency_tolerance or amplitude < amplitude_min or noise_max > (amplitude + noise_limit):
        return 'FAILED'
    else:
        return 'PASSED'


async def load_pmt3_counting_measurement(op_id, window_ms):
    """Build the US_LUM pulse counting trigger sequence and load it under ``op_id``.

    The sequence clears result buffer addresses 0-3 and resets the US_LUM pulse counter. It then
    accumulates the counter once per ``us_tick_delay`` timer period, for ``round(window_ms * 1000)``
    periods. Finally it stores the cumulative count (addresses 0/1, dword) and the dead-time
    value (addresses 2/3, dword). ``meas_unit.resultAddresses[op_id]`` is set to ``range(0, 4)``,
    presumably so ``ReadMeasurementValues`` returns those four words -- confirm in HAL.

    Args:
        op_id: Operation id the sequence is registered and loaded under.
        window_ms: Counting window length in milliseconds.
    """
    window_us = round(window_ms * 1000)
    # Split the period count into 65536-sized chunks plus a remainder, presumably because a
    # single Loop() count is limited to 65536 -- NOTE(review): confirm the sequencer limit.
    window_us_coarse, window_us_fine = divmod(window_us, 65536)

    us_tick_delay = 100     # 1 us

    seq_gen = meas_seq_generator()

    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=0)  # pmt3_cnt_lsb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=1)  # pmt3_cnt_msb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=2)  # pmt3_dt_lsb
    seq_gen.ClearResultBuffer(relative=False, dword=False, addrReg=0, addr=3)  # pmt3_dt_msb

    seq_gen.TimerWaitAndRestart(us_tick_delay)
    seq_gen.PulseCounterControl(MeasurementChannel.US_LUM, cumulative=False, resetCounter=True, resetPresetCounter=True, correctionOn=False)
    if window_us_coarse > 0:
        seq_gen.Loop(window_us_coarse)
        seq_gen.Loop(65536)
        seq_gen.TimerWaitAndRestart(us_tick_delay)
        seq_gen.PulseCounterControl(MeasurementChannel.US_LUM, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
        seq_gen.LoopEnd()
        seq_gen.LoopEnd()
    if window_us_fine > 0:
        seq_gen.Loop(window_us_fine)
        seq_gen.TimerWaitAndRestart(us_tick_delay)
        seq_gen.PulseCounterControl(MeasurementChannel.US_LUM, cumulative=True, resetCounter=False, resetPresetCounter=True, correctionOn=True)
        seq_gen.LoopEnd()

    seq_gen.GetPulseCounterResult(MeasurementChannel.US_LUM, deadTime=False, relative=False, resetCounter=False, cumulative=True, dword=True, addrPos=0, resultPos=0)
    seq_gen.GetPulseCounterResult(MeasurementChannel.US_LUM, deadTime=True, relative=False, resetCounter=True, cumulative=True, dword=True, addrPos=0, resultPos=2)

    # Every Loop() above is closed inside its own branch, so no further LoopEnd() is emitted
    # here. An unmatched LoopEnd() before Stop() would not close any loop.
    seq_gen.Stop(0)

    meas_unit.resultAddresses[op_id] = range(0, 4)
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)
    await meas_unit.LoadTriggerSequence(op_id, seq_gen.currSequence)

